﻿namespace HZY.Framework.Aop.Attributes;

/// <summary>
/// Base AOP interceptor attribute. Gives derived interceptors helpers for resolving
/// services and loggers while an intercepted method is running.
/// </summary>
public class AopMoAttribute : MoAttribute
{
    // Host registered through SetHost; used as the last-resort service source.
    private static IHost? Host;

    // Ambient provider for the current async flow, registered through SetServiceProvider.
    private static readonly AsyncLocal<IServiceProvider> ServiceProvider = new();

    // Scope most recently created by GetScope in the current async flow.
    private static readonly AsyncLocal<IServiceScope?> Scope = new();

    private static readonly object _lock = new();

    /// <summary>
    /// Sets the host used as the fallback service source. Once a host has been set,
    /// later calls are ignored.
    /// </summary>
    /// <param name="host">The application host.</param>
    public static void SetHost(IHost host)
    {
        if (Host is not null) return;
        Host = host;
    }

    /// <summary>
    /// Sets the ambient service provider for the current async flow
    /// (stored in an <see cref="AsyncLocal{T}"/>).
    /// </summary>
    /// <param name="serviceProvider">The provider to use for the current async flow.</param>
    public static void SetServiceProvider(IServiceProvider serviceProvider) => ServiceProvider.Value = serviceProvider;

    /// <summary>
    /// Creates a new service scope from the registered host for the current async flow.
    /// Before creating it, the method disposes the scope that the previous call in the same flow created, if any.
    /// </summary>
    /// <returns>The new scope, or <c>null</c> when no host has been set.</returns>
    /// <remarks>
    /// NOTE(review): disposing the previous scope also disposes any disposable scoped/transient
    /// services that an earlier call resolved from it. Confirm that callers do not hold on to them.
    /// </remarks>
    protected static IServiceScope? GetScope()
    {
        lock (_lock)
        {
            if (Scope.Value is { } previous)
            {
                previous.Dispose();
                Scope.Value = null;
            }

            Scope.Value = Host?.Services?.CreateScope();
            return Scope.Value;
        }
    }

    /// <summary>
    /// Resolves a service. Lookup order:
    /// 1. the provider exposed by the intercepted instance (see <see cref="GetServiceProvider"/>), if it yields a non-null service;
    /// 2. otherwise the ambient provider set through <see cref="SetServiceProvider"/>, whose result is returned as-is;
    /// 3. otherwise a fresh scope from the registered host.
    /// </summary>
    /// <param name="context">The intercepted method's context.</param>
    /// <param name="type">The service type to resolve.</param>
    /// <returns>The service, or <c>null</c> if it could not be resolved.</returns>
    protected object? GetService(MethodContext context, Type type)
    {
        try
        {
            var serviceProvider = GetServiceProvider(context);
            if (serviceProvider is not null)
            {
                var service = serviceProvider.GetService(type);
                if (service is not null)
                {
                    return service;
                }
            }

            if (ServiceProvider.Value is { } ambientProvider)
            {
                return ambientProvider.GetService(type);
            }

            return GetScope()?.ServiceProvider.GetService(type);
        }
        catch (Exception ex)
        {
            LogInternalError(ex);

            // Last resort: resolve from a fresh host scope.
            return GetScope()?.ServiceProvider.GetService(type);
        }
    }

    /// <summary>
    /// Resolves a service of type <typeparamref name="TService"/>, using the same lookup
    /// order as <see cref="GetService(MethodContext, Type)"/>.
    /// </summary>
    /// <typeparam name="TService">The service type to resolve.</typeparam>
    /// <param name="context">The intercepted method's context.</param>
    /// <returns>The service, or <c>default</c> if it could not be resolved.</returns>
    protected TService? GetService<TService>(MethodContext context)
        => GetService(context, typeof(TService)) is TService service ? service : default;

    /// <summary>
    /// Resolves an <see cref="ILogger"/> whose category is the intercepted target type.
    /// Falls back to <c>ILogger&lt;T&gt;</c> in these cases: the target type is unknown, is abstract,
    /// is generic, or building <c>ILogger&lt;TargetType&gt;</c> fails.
    /// </summary>
    /// <typeparam name="T">Fallback logger category type.</typeparam>
    /// <param name="context">The intercepted method's context.</param>
    /// <returns>The logger, or <c>null</c> if it could not be resolved.</returns>
    protected ILogger? GetLogger<T>(MethodContext context)
    {
        var targetType = context.TargetType;

        if (targetType is null || targetType.IsAbstract || targetType.IsGenericType)
        {
            return ResolveLogger(context, typeof(T));
        }

        try
        {
            return ResolveLogger(context, targetType);
        }
        catch (Exception ex)
        {
            LogInternalError(ex);
            return ResolveLogger(context, typeof(T));
        }
    }

    // Resolves ILogger<categoryType> through GetService; MakeGenericType may throw for unusable category types.
    private ILogger? ResolveLogger(MethodContext context, Type categoryType)
        => GetService(context, typeof(ILogger<>).MakeGenericType(categoryType)) as ILogger;

    /// <summary>
    /// Finds the <see cref="IServiceProvider"/> exposed by the intercepted instance.
    /// If the target implements <c>IAopServiceProvider</c>, its provider is used. Otherwise the
    /// method uses reflection: it takes the first non-null readable, non-indexer property of type
    /// <see cref="IServiceProvider"/>, then the first such field. It checks <paramref name="type"/>
    /// first and then walks up its base types.
    /// </summary>
    /// <param name="context">The intercepted method's context.</param>
    /// <param name="type">Type to inspect; defaults to the runtime type of <c>context.Target</c>.</param>
    /// <returns>The provider, or <c>null</c> if none was found.</returns>
    protected virtual IServiceProvider? GetServiceProvider(MethodContext context, Type? type = null)
    {
        // A target implementing IAopServiceProvider exposes its provider directly.
        if (context.Target is IAopServiceProvider aopServiceProvider)
        {
            return aopServiceProvider.ServiceProvider;
        }

        try
        {
            type ??= context.Target?.GetType();
            if (type is null)
            {
                return null;
            }

            var serviceProvider = FindServiceProviderMember(context.Target, type);
            if (serviceProvider is not null)
            {
                return serviceProvider;
            }

            // GetProperties/GetFields on a derived type do not return private members declared
            // on its base classes, so continue the search one level up the hierarchy.
            return type.BaseType is null ? null : GetServiceProvider(context, type.BaseType);
        }
        catch (Exception ex)
        {
            // The reflection probe is best-effort: record the failure and report "not found".
            LogInternalError(ex);
            return null;
        }
    }

    // Returns the first non-null IServiceProvider-typed property (preferred) or field that reflection reports on `type`.
    private static IServiceProvider? FindServiceProviderMember(object? target, Type type)
    {
        const BindingFlags flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static;

        foreach (var property in type.GetProperties(flags))
        {
            // Skip write-only properties and indexers: GetValue(target) cannot read them and would throw.
            if (property.PropertyType != typeof(IServiceProvider)
                || !property.CanRead
                || property.GetIndexParameters().Length != 0)
            {
                continue;
            }

            // A null-valued member does not end the search; keep looking for one that is set.
            if (property.GetValue(target) is IServiceProvider provider)
            {
                return provider;
            }
        }

        foreach (var field in type.GetFields(flags))
        {
            if (field.FieldType == typeof(IServiceProvider) && field.GetValue(target) is IServiceProvider provider)
            {
                return provider;
            }
        }

        return null;
    }

    // Logs an internal failure through ILogger<AopMoAttribute> resolved from a short-lived host scope.
    private static void LogInternalError(Exception ex)
    {
        try
        {
            using var scopeLogger = Host?.Services?.CreateScope();
            var logger = scopeLogger?.ServiceProvider.GetService<ILogger<AopMoAttribute>>();
            logger?.LogError(ex, ex.Message);
        }
        catch (Exception)
        {
            // Logging is best-effort: a missing or disposed host must not mask the caller's fallback path.
        }
    }
}